package cn.doitedu.ml.bayes

import cn.doitedu.commons.util.SparkUtil
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.NaiveBayesModel
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame

import scala.collection.mutable

/**
 * @date: 2020/2/20
 * @site: www.doitedu.cn
 * @author: hunter.d (Tao Ge)
 * @qq: 657270652
 * @description:
 *    Uses a previously trained naive Bayes model to predict the probability
 *    of infidelity for records whose outcome is unknown.
 */
object AmourBayesPredict {

  /**
   * Job entry point: reads the CSV data set to be scored, converts its
   * categorical columns into a numeric feature vector, loads the persisted
   * [[NaiveBayesModel]] and prints the first 10 rows of the scored result.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    // Only WARN and above from Spark itself, so the printed result stays readable.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)

    val spark = SparkUtil.getSparkSession(this.getClass.getSimpleName)

    // Everything that uses the session runs inside try/finally so that the
    // session is released even when reading, model loading or scoring fails.
    try {
      // 1. Load the data set to be predicted; the first CSV line is the header.
      val test = spark.read.option("header", "true").csv("userprofile/data/demo/bayes/test")

      // 2. Feature processing.
      // Encode every categorical column as a double. Values that match none of
      // the listed labels (including NULL) fall into the ELSE branch.
      // NOTE(review): presumably this encoding has to mirror the one used when
      // the model was trained - verify against the training job.
      val test2 = test.selectExpr(
        "name",
        "cast(case job when '老师' then 1.0 when '程序员' then 2.0 else 3.0 end as double) as job",
        "cast(case income when '低' then 1.0 when '中' then 2.0 else 3.0 end as double) as income",
        "cast(case age when '青年' then 1.0 when '中年' then 2.0 else 3.0 end as double) as age",
        "cast(case sex when '男' then 1.0 else 2.0 end as double) as sex"
      )

      // Pack the encoded columns into one dense ML vector. The UDF parameter is
      // declared as Seq[Double], the Scala type Spark documents for array
      // columns, rather than the runtime-specific WrappedArray implementation.
      val toVec = (values: Seq[Double]) => Vectors.dense(values.toArray)
      spark.udf.register("to_vec", toVec)
      val test3 = test2.selectExpr("name", "to_vec(array(job,income,age,sex)) as features")

      // 3. Load the trained model.
      val model = NaiveBayesModel.load("userprofile/data/demo/bayes/model")

      // 4. Feed the prepared data into the model's transform method to obtain
      //    the prediction result.
      val predict: DataFrame = model.transform(test3)

      // Print the first 10 rows without truncating long cell values.
      predict.show(10, truncate = false)
    } finally {
      spark.close()
    }
  }

}
